# Accelerate config selection: exactly one ACC_CONFIG_FILE line active at a time.
# ACC_CONFIG_FILE="configs/acc_single_default.yaml"
ACC_CONFIG_FILE="configs/acc_multi_default.yaml"
# ACC_CONFIG_FILE="configs/acc_multi_deepspeed.yaml"
# ACC_CONFIG_FILE="configs/acc_multi_machine.yaml"
# ACC_CONFIG_FILE="configs/acc_multi_machine_deepspeed.yaml"

NUM_NODE=1
GPU_IDS="0,1,2,3,4,5,6,7"
NGPU_PER_NODE=8 # ! make sure compatible with GPU_IDS
MASTER_ADDR="192.168.16.4"
MASTER_PORT=29500

# Total worker processes across all nodes.
# Arithmetic expansion replaces the legacy backticks + `expr` idiom,
# which forks an external process and breaks on some inputs.
NUM_PROCESSES=$(( NUM_NODE * NGPU_PER_NODE ))

# * multi-machine training (alternative launcher; requires $RANK to be set per node):
# accelerate launch --config_file $ACC_CONFIG_FILE --num_machines $NUM_NODE --num_processes $NUM_PROCESSES --gpu_ids $GPU_IDS --machine_rank $RANK --main_process_ip $MASTER_ADDR --main_process_port $MASTER_PORT \
# * multi-GPU (single-node) training:

# Learning-rate sweep: one full training run per value; each run writes to
# its own output dir (logs/v1-1-1_<lr>) so checkpoints don't collide.
lr_list=(1e-3 3e-4 1e-4 3e-5 1e-5 3e-6)
for lr in "${lr_list[@]}"
do
	# All expansions quoted (ShellCheck SC2086) so values containing
	# whitespace cannot be word-split into extra arguments.
	accelerate launch --config_file "$ACC_CONFIG_FILE" --num_processes "$NUM_PROCESSES" --gpu_ids "$GPU_IDS" --main_process_port "$MASTER_PORT" \
	train.py \
		--tracker_project_name reflow \
		--pretrained_model_name_or_path checkpoints/SD-1-5 \
		--vae_name_or_path checkpoints/sd-vae-ft-mse \
		--unet_from_config configs/unet_config_DeepFloydStyle_medium_256px.json \
		--dataset_name coco2014 \
		--overwrite_dataset \
		--output_dir "logs/v1-1-1_$lr" \
		--seed=1234 \
		--resolution=256 \
		--random_flip \
		--center_crop \
		--train_batch_size=64 \
		--max_train_steps=20000 \
		--gradient_accumulation_steps=1 \
		--learning_rate="$lr" \
		--lr_scheduler "constant_with_warmup" \
		--lr_warmup_steps=1000 \
		--dataloader_num_workers=8 \
		--mixed_precision="fp16" \
		--adam_weight_decay=1e-3 \
		--checkpointing_steps=5000 \
		--use_ema \
		--gradient_checkpointing \
		--enable_xformers_memory_efficient_attention \
		--loss_type l1,l2 \
		--reduction_method mean \
		--p_uncond=0.1 \
		--t_schedule uniform
		# NOTE: the last argument above must NOT end with a trailing backslash —
		# the original relied on a blank line to terminate the continuation,
		# which silently breaks if a commented flag below is uncommented.
		# Optional flags (move above --t_schedule and add a trailing \ to use):
		# --pixel_loss \
		# --resume_from_checkpoint checkpoint-30 \
		# --load_pretrained_weights logs/test/weights-50 \
		# --load_in_8bit \
		# --use_8bit_adam \
		# --llm_name_or_path bert-base-uncased \
		# --snr_gamma=5.0 \
done